import torch
import torchvision
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch import optim
from torch.autograd import Variable
from Trace import LossScaledTrace
import time
import copy

def MultiSGD(train_data, test_data, train_size=6400, bs=32, eps=50, num_workers=8):
    """Train a small CNN with plain SGD and record trace statistics during training.

    The network consists of two conv/ReLU/max-pool stages followed by a linear
    layer with 10 outputs.  The linear layer takes ``32 * 7 * 7`` features, so
    inputs are expected to be single-channel 28x28 images.

    Every 10th epoch a deep copy of the model is handed to ``LossScaledTrace``
    three times with identical arguments, and the three returned values of each
    call are appended to the ``*1``, ``*2`` and ``*3`` result lists respectively.
    NOTE(review): the repeated calls presumably sample a stochastic estimator;
    confirm against the ``Trace`` module.

    After every epoch the training loss is evaluated on a deep copy of
    ``train_data`` (batch size 100) as the mean of the per-batch losses.

    Args:
        train_data: Dataset used for training, for the per-epoch training-loss
            evaluation and as the ``train_data`` argument of ``LossScaledTrace``.
        test_data: Dataset used for the final accuracy report.
        train_size: Forwarded unchanged to ``LossScaledTrace``.
        bs: Mini-batch size of the training loader.
        eps: Number of training epochs.
        num_workers: Worker processes used by each ``DataLoader``.  Defaults to
            8, the value that used to be hard-coded; pass 0 to load data in the
            main process.

    Returns:
        An 11-tuple ``(Epochs, ProductTraces1, Frobeniuses1, HessianTraces1,
        ProductTraces2, Frobeniuses2, HessianTraces2, ProductTraces3,
        Frobeniuses3, HessianTraces3, TrainLosses)``.  ``Epochs`` holds the
        1-based epoch numbers at which the trace statistics were recorded;
        ``TrainLosses`` has one entry per epoch.
    """
    start_time = time.time()
    num_epochs = eps
    learning_rate = 0.1
    train_bs = bs
    eval_bs = 100
    test_bs = 100
    log_every = 100    # print the running loss every this many training steps
    trace_every = 10   # record the trace statistics every this many epochs

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        print("Working on GPU")
    else:
        print("Working on CPU")
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # The training loss is evaluated on an independent copy of the train set.
    eval_data = copy.deepcopy(train_data)

    def make_loader(dataset, batch_size):
        # Pinned host memory only helps host-to-GPU copies, so it is requested
        # only when CUDA is actually in use.
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers,
                          pin_memory=use_cuda)

    loaders = {
        'train': make_loader(train_data, train_bs),
        'eval': make_loader(eval_data, eval_bs),
        'test': make_loader(test_data, test_bs),
    }

    ProductTraces1 = []
    Frobeniuses1 = []
    HessianTraces1 = []
    ProductTraces2 = []
    Frobeniuses2 = []
    HessianTraces2 = []
    ProductTraces3 = []
    Frobeniuses3 = []
    HessianTraces3 = []
    Epochs = []
    TrainLosses = []

    # One (product trace, Frobenius, Hessian trace) list triple per repeated
    # LossScaledTrace call.
    trace_records = [
        (ProductTraces1, Frobeniuses1, HessianTraces1),
        (ProductTraces2, Frobeniuses2, HessianTraces2),
        (ProductTraces3, Frobeniuses3, HessianTraces3),
    ]

    class CNN(nn.Module):
        def __init__(self):
            super(CNN, self).__init__()
            self.conv1 = nn.Sequential(
                nn.Conv2d(
                    in_channels=1,
                    out_channels=16,
                    kernel_size=5,
                    stride=1,
                    padding=2,
                ),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(16, 32, 5, 1, 2),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )
            # fully connected layer, output 10 classes
            self.out = nn.Linear(32 * 7 * 7, 10)

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
            x = x.view(x.size(0), -1)
            output = self.out(x)
            return output, x    # return x for visualization

    cnn = CNN()
    cnn = cnn.to(device)
    cnn = nn.DataParallel(cnn)
    print(cnn)

    loss_func = nn.CrossEntropyLoss()
    print(loss_func)

    optimizer = optim.SGD(cnn.parameters(), lr=learning_rate)
    print(optimizer)

    def train(num_epochs, cnn, loaders):
        # Total number of scalar parameters; passed to LossScaledTrace as ``d``.
        total_params = sum(p.numel() for p in cnn.parameters())
        total_step = len(loaders['train'])

        for epoch in range(num_epochs):
            for i, (images, labels) in enumerate(loaders['train']):
                b_x = images.to(device)
                b_y = labels.to(device)

                # The model returns (logits, flattened features); only the
                # logits enter the loss.
                output = cnn(b_x)[0]
                loss = loss_func(output, b_y)

                # clear gradients for this training step
                optimizer.zero_grad()
                # backpropagation, compute gradients
                loss.backward()
                # apply gradients
                optimizer.step()

                if (i + 1) % log_every == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.14f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

            if (epoch + 1) % trace_every == 0:
                # The statistics are computed on a copy, never on the model
                # that is being trained.
                modelcopy = copy.deepcopy(cnn)
                for run, (products, frobeniuses, hessians) in enumerate(trace_records):
                    ProductTrace, Frobenius, HessianTrace = LossScaledTrace(test_model=modelcopy,
                                                                            train_data=train_data,
                                                                            d=total_params,
                                                                            train_size=train_size,
                                                                            B=10000)
                    products.append(ProductTrace)
                    frobeniuses.append(Frobenius)
                    hessians.append(HessianTrace)
                    if run == 0:
                        Epochs.append(epoch + 1)

            # Training loss after this epoch, evaluated on a copy switched to
            # eval mode so the mode of the trained model is left untouched.
            modelcopy = copy.deepcopy(cnn)
            modelcopy.eval()
            with torch.no_grad():
                train_loss = 0
                for images, labels in loaders['eval']:
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs = modelcopy(images)[0]
                    train_loss += loss_func(outputs, labels).item()
                train_loss /= len(loaders['eval'])
            TrainLosses.append(train_loss)

    train(num_epochs, cnn, loaders)

    def test():
        # Report the fraction of correctly classified test samples.  Counting
        # samples (instead of averaging per-batch accuracies and dividing by the
        # batch size) keeps the result correct for any test-set size, including
        # a smaller final batch.
        cnn.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            for images, labels in loaders['test']:
                images = images.to(device)
                labels = labels.to(device)
                test_output, _ = cnn(images)
                pred_y = test_output.argmax(dim=1)
                correct += (pred_y == labels).sum().item()
                total += labels.size(0)
            # An empty test loader reports 0 rather than dividing by zero.
            accuracy = correct / float(total) if total else 0.0
            print('SGD Test Accuracy of the model on the 10000 test images: %.4f' % accuracy)

    test()
    print(f'SGD RunTime: {time.time() - start_time:.2f}')

    return Epochs, ProductTraces1, Frobeniuses1, HessianTraces1, ProductTraces2, Frobeniuses2, HessianTraces2, ProductTraces3, Frobeniuses3, HessianTraces3, TrainLosses